{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A Custom Estimator for a Convolutional Neural Network\n",
    "\n",
    "Hello intrepid reader! In this notebook, we'll add a function that uses ```tf.layers``` to build a vanilla CNN. This should achieve around 99% accuracy on MNIST (there is still plenty of room to improve). Have a look at the ```build_cnn``` function where we define the model. Aside from that (and changing our preprocessing to no longer 'flatten' the images, and to add a color channel dimension), the code otherwise remains unchanged. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import the dataset. Here, we'll need to convert the labels to a one-hot encoding, and we'll reshape the MNIST images to (28, 28, 1), adding the color channel dimension that the CNN expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load MNIST through the Keras datasets module bundled with TensorFlow.\n",
    "# All preprocessing and reshaping happens here rather than in the model.\n",
    "(x_train, y_train), (x_test, y_test) = tf.contrib.keras.datasets.mnist.load_data()\n",
    "\n",
    "# As imported, the color values are 0-255: cast to float32 and scale to 0-1.\n",
    "# The CNN defined later expects a color channel dimension, so every image\n",
    "# is also reshaped to (28, 28, 1).\n",
    "x_train = (x_train.astype('float32') / 255).reshape(-1, 28, 28, 1)\n",
    "x_test = (x_test.astype('float32') / 255).reshape(-1, 28, 28, 1)\n",
    "\n",
    "# Cast the class labels to int32, then convert them to one-hot rows.\n",
    "y_train = tf.contrib.keras.utils.to_categorical(y_train.astype('int32'), num_classes=10)\n",
    "y_test = tf.contrib.keras.utils.to_categorical(y_test.astype('int32'), num_classes=10)\n",
    "\n",
    "print(len(x_train), 'train samples')\n",
    "print(len(x_test), 'test samples')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This function defines our CNN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def build_cnn(features, mode):\n",
    "    \"\"\"Build the CNN graph and return the logits for a batch of images.\n",
    "\n",
    "    Args:\n",
    "        features: dict of input tensors; the image batch is read from key 'x'.\n",
    "        mode: a tf.estimator.ModeKeys value. Dropout is only applied when\n",
    "            mode is TRAIN.\n",
    "\n",
    "    Returns:\n",
    "        The output tensor of the final 10-unit dense layer (the logits).\n",
    "    \"\"\"\n",
    "    is_training = (mode == tf.estimator.ModeKeys.TRAIN)\n",
    "    net = features['x']\n",
    "\n",
    "    # Stage 1: 32 3x3 filters with ReLU, then 2x2 max pooling (stride 2).\n",
    "    with tf.name_scope(\"conv1\"):\n",
    "        net = tf.layers.conv2d(\n",
    "            inputs=net, filters=32, kernel_size=[3, 3],\n",
    "            padding='same', activation=tf.nn.relu)\n",
    "\n",
    "    with tf.name_scope(\"pool1\"):\n",
    "        net = tf.layers.max_pooling2d(inputs=net, pool_size=[2, 2], strides=2)\n",
    "\n",
    "    # Stage 2: 64 3x3 filters with ReLU, then 2x2 max pooling (stride 2).\n",
    "    with tf.name_scope(\"conv2\"):\n",
    "        net = tf.layers.conv2d(\n",
    "            inputs=net, filters=64, kernel_size=[3, 3],\n",
    "            padding='same', activation=tf.nn.relu)\n",
    "\n",
    "    with tf.name_scope(\"pool2\"):\n",
    "        net = tf.layers.max_pooling2d(inputs=net, pool_size=[2, 2], strides=2)\n",
    "\n",
    "    with tf.name_scope(\"dense\"):\n",
    "        # Assuming 28x28 inputs, the feature maps are now 7x7 (28 / 2 / 2)\n",
    "        # with 64 channels each; flatten them into one vector per image.\n",
    "        flat_size = 7 * 7 * 64\n",
    "        net = tf.reshape(net, [-1, flat_size])\n",
    "        net = tf.layers.dense(inputs=net, units=128, activation=tf.nn.relu)\n",
    "\n",
    "    with tf.name_scope(\"dropout\"):\n",
    "        # rate is the drop probability, so each unit is kept with probability 0.8.\n",
    "        net = tf.layers.dropout(inputs=net, rate=0.2, training=is_training)\n",
    "\n",
    "    # Final projection to one score per class (10 units).\n",
    "    return tf.layers.dense(inputs=net, units=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To write a Custom Estimator we'll specify our own model function. Here, we'll use ```tf.layers``` to replicate the model from the third notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def model_fn(features, labels, mode):\n",
    "    \"\"\"Model function for the custom Estimator.\n",
    "\n",
    "    Builds the network with build_cnn and returns the EstimatorSpec that\n",
    "    matches the requested mode.\n",
    "\n",
    "    Args:\n",
    "        features: dict of input tensors, passed through to build_cnn.\n",
    "        labels: one-hot encoded labels; not used in PREDICT mode.\n",
    "        mode: a tf.estimator.ModeKeys value (TRAIN, EVAL or PREDICT).\n",
    "\n",
    "    Returns:\n",
    "        A tf.estimator.EstimatorSpec carrying predictions only (PREDICT),\n",
    "        predictions, loss and train_op (TRAIN), or predictions, loss and\n",
    "        eval metrics (EVAL).\n",
    "    \"\"\"\n",
    "    logits = build_cnn(features, mode)\n",
    "\n",
    "    # Generate Predictions\n",
    "    classes = tf.argmax(logits, axis=1)\n",
    "    predictions = {\n",
    "        'classes': classes,\n",
    "        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')\n",
    "    }\n",
    "\n",
    "    if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "        # Return an EstimatorSpec for prediction\n",
    "        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)\n",
    "\n",
    "    # Compute the loss, per usual.\n",
    "    loss = tf.losses.softmax_cross_entropy(\n",
    "        onehot_labels=labels, logits=logits)\n",
    "\n",
    "    if mode == tf.estimator.ModeKeys.TRAIN:\n",
    "\n",
    "        # Configure the Training Op\n",
    "        train_op = tf.contrib.layers.optimize_loss(\n",
    "            loss=loss,\n",
    "            global_step=tf.train.get_global_step(),\n",
    "            learning_rate=1e-3,\n",
    "            optimizer='Adam')\n",
    "\n",
    "        # Return an EstimatorSpec for training\n",
    "        return tf.estimator.EstimatorSpec(mode=mode,\n",
    "                                          predictions=predictions,\n",
    "                                          loss=loss,\n",
    "                                          train_op=train_op)\n",
    "\n",
    "    assert mode == tf.estimator.ModeKeys.EVAL\n",
    "\n",
    "    # Configure the accuracy metric for evaluation.\n",
    "    # tf.metrics.accuracy takes (labels, predictions); pass both by keyword so\n",
    "    # each tensor lands in the right slot. The labels are one-hot, so argmax\n",
    "    # recovers the class index to compare against the predicted class.\n",
    "    metrics = {\n",
    "        'accuracy': tf.metrics.accuracy(\n",
    "            labels=tf.argmax(labels, axis=1), predictions=classes)\n",
    "    }\n",
    "\n",
    "    return tf.estimator.EstimatorSpec(mode=mode,\n",
    "                                      predictions=predictions,\n",
    "                                      loss=loss,\n",
    "                                      eval_metric_ops=metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Input functions, as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Training input: shuffle the examples and repeat the dataset forever\n",
    "# (num_epochs=None).\n",
    "train_input = tf.estimator.inputs.numpy_input_fn(\n",
    "    x={'x': x_train}, y=y_train, num_epochs=None, shuffle=True)\n",
    "\n",
    "# Test input: loop through the dataset once, without shuffling.\n",
    "test_input = tf.estimator.inputs.numpy_input_fn(\n",
    "    x={'x': x_test}, y=y_test, num_epochs=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Create the custom Estimator from our model function.\n",
    "estimator = tf.estimator.Estimator(\n",
    "    model_fn=model_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run 2000 training steps.\n",
    "# On a machine without a GPU, this can take some time.\n",
    "estimator.train(\n",
    "    input_fn=train_input,\n",
    "    steps=2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Evaluate the trained estimator with the test input function.\n",
    "# The printed results should include our accuracy metric.\n",
    "# By tweaking the params of the model, you can get >99% accuracy.\n",
    "evaluation = estimator.evaluate(\n",
    "    input_fn=test_input)\n",
    "print(evaluation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Here's how to print predictions on a few examples\n",
    "MAX_TO_PRINT = 5\n",
    "\n",
    "# This returns a generator object\n",
    "predictions = estimator.predict(input_fn=test_input)\n",
    "\n",
    "# zip() stops as soon as range() is exhausted, so at most MAX_TO_PRINT\n",
    "# predictions are consumed (and none at all when MAX_TO_PRINT is 0).\n",
    "# test_input does not shuffle, so the i-th prediction lines up with y_test[i].\n",
    "for i, p in zip(range(MAX_TO_PRINT), predictions):\n",
    "    true_label = np.argmax(y_test[i])\n",
    "    predicted_label = p['classes']\n",
    "    print(\"Example %d. True: %d, Predicted: %s\" % (i, true_label, predicted_label))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
